#include <memory>
#include <sodium.h>
#include <string>
#include <utility>

#include "dnscrypt/dns_crypt_client.h"
#include "common/net_utils.h"
#include "common/utils.h"
#include "dnscrypt/dns_crypt_consts.h"
#include "dnscrypt/dns_crypt_ldns.h"
#include "dnscrypt/dns_crypt_utils.h"
#include "dnsstamp/dns_stamp.h"

namespace ag::dnscrypt {

/**
 * Rewrites the EDNS UDP payload size advertised by a query before it is sent upstream.
 * Follows dnscrypt-proxy's payload size plugin:
 * See https://github.com/jedisct1/dnscrypt-proxy/blob/master/dnscrypt-proxy/plugin_get_set_payload_size.go
 * See also https://github.com/jedisct1/dnscrypt-proxy/issues/667
 *
 * The limit starts at LDNS_MIN_BUFLEN - QUERY_OVERHEAD. Only for an EDNS version 0 query is it raised to
 * at least MAX_DNS_UDP_SAFE_PACKET_SIZE - QUERY_OVERHEAD. The result is capped at
 * MAX_DNS_PACKET_SIZE - QUERY_OVERHEAD, and it is written into the packet only if it exceeds LDNS_MIN_BUFLEN.
 * The UDP size the query originally advertised is not consulted.
 *
 * @param msg Query packet, modified in place
 */
static void ldns_pkt_adjust_payload_size(ldns_pkt &msg) {
    size_t payload_limit = LDNS_MIN_BUFLEN - QUERY_OVERHEAD;

    const bool has_edns0 = ldns_pkt_edns(&msg) && ldns_pkt_edns_version(&msg) == 0;
    if (has_edns0) {
        payload_limit = std::max(MAX_DNS_UDP_SAFE_PACKET_SIZE - QUERY_OVERHEAD, payload_limit);
    }
    payload_limit = std::min(MAX_DNS_PACKET_SIZE - QUERY_OVERHEAD, payload_limit);

    // Nothing to advertise beyond the plain DNS message size.
    if (payload_limit <= LDNS_MIN_BUFLEN) {
        return;
    }
    ldns_pkt_set_edns_udp_size(&msg, payload_limit);
}

Client::Client(bool adjust_payload_size)
        : Client(DEFAULT_PROTOCOL, adjust_payload_size) {
}

Client::Client(utils::TransportProtocol protocol, bool adjust_payload_size)
        : m_protocol(protocol)
        , m_adjust_payload_size(adjust_payload_size) {
}

/**
 * Parses a DNS stamp string and dials the DNSCrypt server it describes.
 * @param stamp_str         Stamp in its string form
 * @param timeout           Passed through to the stamp-based `dial()` overload
 * @param socket_factory    Passed through to the stamp-based `dial()` overload
 * @param socket_parameters Passed through to the stamp-based `dial()` overload
 * @return Result of the stamp-based `dial()`; or an error if the stamp cannot be parsed
 *         or describes a protocol other than DNSCrypt
 */
Client::DialResult Client::dial(std::string_view stamp_str, Millis timeout, const SocketFactory *socket_factory,
        SocketFactory::SocketParameters socket_parameters) const {
    static constexpr utils::MakeError<DialResult> make_error;

    auto [parsed_stamp, parse_error] = ServerStamp::from_string(stamp_str);
    if (parse_error) {
        return make_error(std::move(parse_error));
    }

    // Only DNSCrypt stamps are accepted by this client.
    if (parsed_stamp.proto == StampProtoType::DNSCRYPT) {
        return dial(parsed_stamp, timeout, socket_factory, std::move(socket_parameters));
    }
    return make_error("Stamp is not for a DNSCrypt server");
}

/**
 * Prepares a session with the DNSCrypt server described by `stamp`.
 * Generates a fresh client keypair with libsodium, copies the server public key, address and provider name from
 * the stamp (appending DEFAULT_PORT when the address has no port, and a trailing '.' to the provider name),
 * then fetches the server's current certificate.
 * @param stamp             Parsed server stamp
 * @param timeout           Timeout for the certificate fetch
 * @param socket_factory    Factory used for the certificate fetch
 * @param socket_parameters Socket parameters for the certificate fetch; `proto` is overridden with this client's protocol
 * @return Server info (including the fetched certificate) and the RTT reported by the certificate fetch, or an error
 */
Client::DialResult Client::dial(const ServerStamp &stamp, Millis timeout, const SocketFactory *socket_factory,
        SocketFactory::SocketParameters socket_parameters) const {
    static constexpr utils::MakeError<DialResult> make_error;
    ServerInfo local_server_info{};
    if (crypto_box_keypair(local_server_info.m_public_key.data(), local_server_info.m_secret_key.data()) != 0) {
        return make_error("Can not generate keypair");
    }
    // Set the provider properties
    local_server_info.m_server_public_key = stamp.server_pk;
    local_server_info.m_server_address = stamp.server_addr_str;
    if (SocketAddress addr = utils::str_to_socket_address(local_server_info.m_server_address); addr.port() == 0) {
        // No port in the stamp: fall back to the default one.
        // An IPv6 literal must be enclosed in brackets; otherwise the appended ":port" would be parsed
        // as part of the address itself (e.g. "::1:443").
        std::string host{addr.host_str()};
        if (host.find(':') != std::string::npos && host.front() != '[') {
            host = AG_FMT("[{}]", host);
        }
        local_server_info.m_server_address = AG_FMT("{}:{}", host, DEFAULT_PORT);
    }
    local_server_info.m_provider_name = stamp.provider_name;
    if (local_server_info.m_provider_name.empty()) {
        return make_error("Provider name is empty");
    }
    // The provider name is used as a fully-qualified name, so make sure it ends with the root label
    if (local_server_info.m_provider_name.back() != '.') {
        local_server_info.m_provider_name.push_back('.');
    }
    socket_parameters.proto = m_protocol;
    // Fetch the certificate and validate it
    auto [cert_info, rtt, err]
            = local_server_info.fetch_current_dnscrypt_cert(timeout, socket_factory, std::move(socket_parameters));
    if (err) {
        return make_error(std::move(err));
    }
    local_server_info.m_server_cert = std::move(cert_info);
    return {std::move(local_server_info), rtt, std::nullopt};
}

/**
 * Encrypts `message`, sends it to the server from `local_server_info`, and returns the decrypted reply.
 * If payload size adjustment is enabled, `message` is modified in place before it is serialized.
 * @param message           Query to send
 * @param local_server_info Session data obtained from `dial()` (keys, server address, certificate)
 * @param timeout           Timeout for the network exchange
 * @param socket_factory    Factory used for the network exchange
 * @param socket_parameters Socket parameters for the exchange; `proto` is overridden with this client's protocol
 * @return The parsed reply and the time elapsed in this call (serialization, encryption, network exchange and
 *         decryption), or an error
 */
Client::ExchangeResult Client::exchange(ldns_pkt &message, const ServerInfo &local_server_info, Millis timeout,
        const SocketFactory *socket_factory, SocketFactory::SocketParameters socket_parameters) const {
    static constexpr utils::MakeError<ExchangeResult> make_error;
    utils::Timer timer;
    if (m_adjust_payload_size) {
        ldns_pkt_adjust_payload_size(message);
    }
    auto [query, create_ldns_buffer_err] = create_ldns_buffer(message);
    if (create_ldns_buffer_err) {
        return make_error(std::move(create_ldns_buffer_err));
    }
    auto [encrypted_query, client_nonce, encrypt_err] = local_server_info.encrypt(
            m_protocol, Uint8View(ldns_buffer_begin(query.get()), ldns_buffer_position(query.get())));
    if (encrypt_err) {
        return make_error(std::move(encrypt_err));
    }
    // ldns_buffer_new_frm_data() copies the data into a heap allocation owned by the buffer
    ldns_buffer encrypted_query_buffer = {};
    ldns_buffer_new_frm_data(&encrypted_query_buffer, encrypted_query.data(), encrypted_query.size());
    // Free that copy on every exit path, including early returns and exceptions
    auto release_buffer_data = [](ldns_buffer *buffer) {
        free(ldns_buffer_export(buffer));
    };
    std::unique_ptr<ldns_buffer, decltype(release_buffer_data)> encrypted_query_guard{
            &encrypted_query_buffer, release_buffer_data};
    // On allocation failure ldns leaves the data pointer NULL and records the error in the buffer status
    if (!ldns_buffer_status_ok(&encrypted_query_buffer)) {
        return make_error("Can not allocate buffer for encrypted query");
    }
    ldns_buffer_set_position(&encrypted_query_buffer, encrypted_query.size());
    socket_parameters.proto = m_protocol;
    auto [encrypted_response, exchange_rtt, exchange_err]
            = dns_exchange(timeout, utils::str_to_socket_address(local_server_info.m_server_address),
                    encrypted_query_buffer, socket_factory, std::move(socket_parameters));
    // The encrypted query is no longer needed: release it now rather than at the end of the function
    encrypted_query_guard.reset();
    if (exchange_err) {
        return make_error(std::move(exchange_err));
    }
    // Reading the response
    // If the server info is no longer valid (for instance, because the certificate was rotated),
    // the read operation will most likely time out.
    // This might be a signal to re-dial for the server certificate.
    auto [decrypted, decrypt_err]
            = local_server_info.decrypt(Uint8View(encrypted_response.data(), encrypted_response.size()),
                    Uint8View(client_nonce.data(), client_nonce.size()));
    if (decrypt_err) {
        return make_error(std::move(decrypt_err));
    }
    auto [reply_pkt_unique_ptr, reply_err] = create_ldns_pkt(decrypted.data(), decrypted.size());
    if (reply_err) {
        return make_error(std::move(reply_err));
    }
    auto rtt = timer.elapsed<Millis>();
    return {std::move(reply_pkt_unique_ptr), rtt, std::nullopt};
}

} // namespace ag::dnscrypt
